package com.game.core.net.handler;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.util.ReferenceCountUtil;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.GeneratedMessage;
import com.kodgames.core.threadPool.OrderedThreadPoolExecutor;
import com.kodgames.core.net.udp.UdpConnectionHandler;
import com.kodgames.core.net.udp.udpkit.UdpConnection;
import com.game.core.net.common.RemoteNode;
import com.game.core.net.message.AbstractCustomizeMessage;
import com.game.core.net.message.InternalMessage;
import com.game.core.net.task.MessageTask;
import com.game.core.service.TransactionMgr;
import com.game.message.protocol.ProtocolsConfig;

/**
 * Netty handler that turns inbound frames into {@link InternalMessage}s for the logic layer and
 * serializes outbound {@link InternalMessage}s back into frames.
 *
 * <p>Inbound frame layout as seen by this handler: on UDP channels a routing prefix of 1-byte
 * channelId + 4-byte IPv4 address + 4-byte port, then 4-byte protocolID, 4-byte callback and the
 * message body. A 1-byte frame is treated as a ping and forwarded unchanged.
 *
 * <p>The handler is {@link Sharable}; per-peer state lives in {@link RemoteNode}s kept in a
 * concurrent map keyed by peer address.
 */
@Sharable
public class MessageProcessor extends ChannelDuplexHandler implements UdpConnectionHandler
{
	private static final Logger logger = LoggerFactory.getLogger(MessageProcessor.class);

	/** Bytes of the protocolID + callback header in front of every logic message. */
	private static final int MESSAGE_HEADER_SIZE = 4 + 4;

	/** Bytes of the UDP routing prefix: 1-byte channelId + 4-byte IPv4 address + 4-byte port. */
	private static final int UDP_ROUTE_HEADER_SIZE = 1 + 4 + 4;

	/**
	 * Remote nodes keyed by peer address. TCP nodes are added in channelActive and UDP nodes in
	 * connectionActive. The handler is shared, hence the concurrent map.
	 */
	private final Map<InetSocketAddress, RemoteNode> remoteNodeMaps = new ConcurrentHashMap<InetSocketAddress, RemoteNode>();

	public static final int MESSAGE_SENDBUFFER_DEFAULTSIZE = 256;
	AbstractMessageInitializer messageInitializer;

	public MessageProcessor(AbstractMessageInitializer messageInitializer)
	{
		this.messageInitializer = messageInitializer;
	}

	/** Copies the readable bytes of buf into a new array without moving its reader index. */
	private byte[] ByteBuf2ByteArray(ByteBuf buf)
	{
		final int length = buf.readableBytes();
		byte[] array = new byte[length];
		buf.getBytes(buf.readerIndex(), array, 0, length);
		return array;
	}

	/** Returns whether protocolID counts towards a node's send/receive amounts (-1 does not). */
	private boolean isAmountScope(int protocolID)
	{
		return protocolID != -1;
	}

	/** Returns whether ctx belongs to a UDP (NioDatagramChannel) channel. */
	private static boolean isUdpChannel(ChannelHandlerContext ctx)
	{
		return ctx.channel() instanceof NioDatagramChannel;
	}

	/**
	 * Consumes the UDP routing prefix (1-byte channelId, 4 IPv4 address bytes, port read with
	 * readInt) from buf and returns the sender address it names.
	 *
	 * @throws IOException never for a 4-byte address; declared by InetAddress.getByAddress
	 */
	private static InetSocketAddress readUdpRemoteAddress(ByteBuf buf) throws IOException
	{
		buf.readByte(); // channelId: not used on the receive path
		byte[] ip = new byte[4];
		buf.readBytes(ip);
		int port = buf.readInt();
		// getByAddress builds the address from raw bytes without any name resolution.
		return new InetSocketAddress(InetAddress.getByAddress(ip), port);
	}

	/**
	 * Decodes one inbound frame and dispatches it to the handler registered for its protocolID.
	 *
	 * <p>Messages that are not a {@link ByteBuf} are dropped and released. A 1-byte frame is a
	 * ping and is forwarded to the next handler, which then owns it. Every other frame is
	 * released here after decoding.
	 *
	 * @throws Exception if the frame cannot be decoded or no handler exists for its protocolID
	 */
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
	{
		if (!(msg instanceof ByteBuf))
		{
			// Not a frame this processor understands. Drop it, but release it so a
			// reference-counted message does not leak.
			ReferenceCountUtil.release(msg);
			return;
		}

		// Set once msg has been handed to the next handler. From then on, releasing it here
		// would be a double release.
		boolean forwarded = false;
		try
		{
			ByteBuf buf = (ByteBuf) msg;
			int len = buf.readableBytes();
			// A single byte is a ping.
			if (len == 1)
			{
				forwarded = true;
				ctx.fireChannelRead(msg);
				return;
			}

			boolean udp = isUdpChannel(ctx);
			// UDP frames carry the routing prefix in front of protocolID + callback.
			int minLength = udp ? UDP_ROUTE_HEADER_SIZE + MESSAGE_HEADER_SIZE : MESSAGE_HEADER_SIZE;
			if (len < minLength)
			{
				logger.error("MessageProcessor.channelRead found error msg len {}", len);
				return;
			}

			RemoteNode remoteNode;
			if (udp)
			{
				// One UDP channel serves many peers and its logic is dispatched on a thread pool,
				// so the channel attribute cannot identify the sender. Look the node up by the
				// address carried in the frame instead.
				remoteNode = remoteNodeMaps.get(readUdpRemoteAddress(buf));
			}
			else
			{
				remoteNode = ctx.channel().attr(RemoteNode.REMOTENODE).get();
			}

			int protocolID = buf.readInt();
			int callback = buf.readInt();
			InternalMessage corgiMsg = new InternalMessage();
			corgiMsg.setProtocolID(protocolID);
			corgiMsg.setCallback(callback);
			corgiMsg.setRemoteNode(remoteNode);

			if (remoteNode != null)
			{
				if (isAmountScope(protocolID))
					remoteNode.incRecvAmount();
				logger.debug("channelRead: protocol id is {}:{}, ip address {}, role id {} len {} recvAmount {}", protocolID, Integer.toHexString(protocolID),
						remoteNode.getAddress(), remoteNode.getRoleId(), len, remoteNode.getRecvAmount());
			}
			else
			{
				logger.debug("channelRead: protocol id is {}:{}, no RemoteNode found, len {}", protocolID, Integer.toHexString(protocolID), len);
			}

			corgiMsg.setMessage(decodeMessage(protocolID, buf));

			BaseMessageHandler<?> handler = messageInitializer.getMessageHandler(protocolID);
			if (handler == null)
			{
				throw new Exception("Illegal protocolID:" + protocolID + ", Can't find corresponding handler.");
			}

			readMessage(handler, corgiMsg);
		}
		catch (Throwable e)
		{
			logger.error("channel read error: Throwable={}", e.toString(), e);
			throw e;
		}
		finally
		{
			if (!forwarded)
			{
				ReferenceCountUtil.release(msg);
			}
		}
	}

	/**
	 * Decodes the body left in buf according to the initializer's message handler type.
	 *
	 * @return the decoded message. Null for a customize-message protocolID with no registered class.
	 * @throws Exception if no protobuf class or parseFrom(byte[]) method exists for protocolID,
	 *         or the handler type is not recognised
	 */
	private Object decodeMessage(int protocolID, ByteBuf buf) throws Exception
	{
		if (messageInitializer.getMsgHandlerType() == AbstractMessageInitializer.CUSTOMIZEMESSAGE_HANDLER)
		{
			Class<?> msgClass = messageInitializer.getMessageClass(protocolID);
			AbstractCustomizeMessage message = msgClass == null ? null : ((AbstractCustomizeMessage) msgClass.newInstance());
			if (message == null)
			{
				logger.error("channelRead found message null protocolID {}:{}", protocolID, Integer.toHexString(protocolID));
				return null;
			}
			message.decode(buf);
			return message;
		}

		if (messageInitializer.getMsgHandlerType() == AbstractMessageInitializer.PROTOBUF_HANDLER)
		{
			Class<?> msgClass = messageInitializer.getMessageClass(protocolID);
			if (msgClass == null)
			{
				throw new Exception("Illegal protocolID:" + protocolID + ", Can't find corresponding Protobuf class.");
			}
			Method parseFrom;
			try
			{
				parseFrom = msgClass.getDeclaredMethod("parseFrom", byte[].class);
			}
			catch (NoSuchMethodException e)
			{
				// getDeclaredMethod never returns null; a missing method arrives as this exception.
				throw new Exception("Illegal protocolID:" + protocolID + ", found class but Can't find parseFrom method.", e);
			}
			return parseFrom.invoke(null, ByteBuf2ByteArray(buf));
		}

		if (messageInitializer.getMsgHandlerType() == AbstractMessageInitializer.BYTEARRAY_HANDLER)
		{
			return ByteBuf2ByteArray(buf);
		}

		throw new Exception("Illegal MsgHandlerType, Please set correct msg type in MessageInitializer.");
	}

	/**
	 * Serializes outbound messages.
	 *
	 * <p>These pass through unchanged:
	 * <ul>
	 * <li>DatagramPackets on a UDP channel (udpkit commands, with no protocolID/callback header)</li>
	 * <li>raw ByteBufs (pings)</li>
	 * <li>anything that is not an InternalMessage</li>
	 * </ul>
	 *
	 * <p>An InternalMessage is written as [UDP routing prefix] + protocolID + callback + body. The
	 * body comes from an AbstractCustomizeMessage, a protobuf GeneratedMessage or a byte[]; any
	 * other payload is sent with an empty body.
	 *
	 * @throws IllegalArgumentException if a UDP InternalMessage carries no RemoteNode
	 */
	@Override
	public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception
	{
		boolean udp = isUdpChannel(ctx);
		if (udp && msg instanceof DatagramPacket)
		{
			// udpkit commands are already wrapped as DatagramPackets. The logic layer never sees
			// them, and they get no protocolID/callback header or compression.
			ctx.write(msg, promise);
			return;
		}
		if (msg instanceof ByteBuf)
		{
			// ping
			ctx.write(msg, promise);
			return;
		}
		if (!(msg instanceof InternalMessage))
		{
			// Not ours to encode: pass it on rather than failing with a ClassCastException.
			ctx.write(msg, promise);
			return;
		}

		InternalMessage internalMsg = (InternalMessage) msg;
		int protocolID = internalMsg.getProtocolID();
		int callback = internalMsg.getCallback();
		byte[] body = encodeBody(internalMsg.getMessage());
		int bodyLength = body == null ? 0 : body.length;

		int capacity = (udp ? UDP_ROUTE_HEADER_SIZE : 0) + MESSAGE_HEADER_SIZE + bodyLength;
		ByteBuf outBuffer = ctx.alloc().buffer(capacity);
		// Set just before ctx.write takes ownership of outBuffer. Netty releases the buffer
		// itself if write fails, so it must not be released here afterwards.
		boolean handedOff = false;
		try
		{
			RemoteNode node;
			if (udp)
			{
				node = internalMsg.getRemoteNode();
				if (node == null)
				{
					logger.error("write Udp Message but without RemoteNode Info then igored data");
					throw new IllegalArgumentException("UDPMessage without RemoteNode info");
				}

				outBuffer.writeByte(internalMsg.getChannelId());
				InetSocketAddress remote = (InetSocketAddress) node.getAddress();
				// NOTE(review): channelRead reads exactly 4 address bytes. An IPv6 peer would write
				// 16 here and break the framing; confirm peers are IPv4 only.
				outBuffer.writeBytes(remote.getAddress().getAddress());
				outBuffer.writeInt(remote.getPort());
			}
			else
			{
				node = remoteNodeMaps.get(ctx.channel().remoteAddress());
			}

			if ((node != null) && isAmountScope(protocolID))
			{
				node.incSendAmount();
				logger.debug("write: protocol id is {}:{}, ip address {}, role id {} len {} sendAmount {}", protocolID, Integer.toHexString(protocolID),
						node.getAddress(), node.getRoleId(), bodyLength, node.getSendAmount());
			}

			outBuffer.writeInt(protocolID);
			outBuffer.writeInt(callback);
			if (body != null)
			{
				outBuffer.writeBytes(body);
			}
			handedOff = true;
			ctx.write(outBuffer, promise);
		}
		catch (Exception e)
		{
			logger.error("channel write error: Exception err={}", e.toString(), e);
			throw e;
		}
		catch (Throwable t)
		{
			logger.error("channel write error: Throwable={}", t.toString(), t);
			throw t;
		}
		finally
		{
			if (!handedOff)
			{
				ReferenceCountUtil.release(outBuffer);
			}
		}
	}

	/** Serializes an outgoing payload. Returns null for a null or unsupported payload type. */
	private static byte[] encodeBody(Object message) throws Exception
	{
		if (message instanceof AbstractCustomizeMessage)
		{
			return ((AbstractCustomizeMessage) message).encode();
		}
		if (message instanceof GeneratedMessage)
		{
			return ((GeneratedMessage) message).toByteArray();
		}
		if (message instanceof byte[])
		{
			return (byte[]) message;
		}
		return null;
	}

	/**
	 * Handles activation of a TCP channel. It creates the RemoteNode for the connection, registers
	 * it in the address map and as the channel's REMOTENODE attribute, and dispatches it to the
	 * connection-active handler.
	 *
	 * <p>UDP channels are ignored here; UDP nodes are created in connectionActive(UdpConnection).
	 * The event is not forwarded down the pipeline.
	 */
	@Override
	public void channelActive(ChannelHandlerContext ctx) throws Exception
	{
		if (isUdpChannel(ctx))
		{
			return;
		}

		logger.debug("ChannelActive channel local {} remote {} active {} create remoteNode", ctx.channel().localAddress(), ctx.channel().remoteAddress(), ctx.channel().isActive());
		RemoteNode remoteNode = new RemoteNode();
		remoteNode.setAddress((InetSocketAddress) ctx.channel().remoteAddress());
		remoteNode.setChannel(ctx.channel());

		remoteNodeMaps.put(remoteNode.getAddress(), remoteNode);

		ctx.channel().attr(RemoteNode.REMOTENODE).set(remoteNode);
		InternalMessage p = new InternalMessage();
		p.setRemoteNode(remoteNode);
		readMessage(messageInitializer.getConnectionActiveHandler(), p);
	}

	/**
	 * Handles deactivation of a TCP channel. It dispatches the channel's RemoteNode to the
	 * connection-inactive handler, then removes the node from the address map. A channel without a
	 * RemoteNode attribute is logged and skipped. UDP channels are ignored here.
	 *
	 * <p>Unlike the default {@link ChannelInboundHandler} behaviour, the event is not forwarded down
	 * the {@link ChannelPipeline}.
	 */
	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception
	{
		if (isUdpChannel(ctx))
		{
			return;
		}

		RemoteNode remoteNode = ctx.channel().attr(RemoteNode.REMOTENODE).get();
		if (remoteNode == null)
		{
			logger.warn("channelInactive channel {} has no RemoteNode", ctx.channel());
			return;
		}

		InternalMessage p = new InternalMessage();
		p.setRemoteNode(remoteNode);
		readMessage(messageInitializer.getConnectionInactiveHandler(), p);

		remoteNodeMaps.remove(remoteNode.getAddress());
	}

	/**
	 * Logs pipeline exceptions. IOExceptions (presumably a peer closing the connection) are logged
	 * at debug level; anything else is logged as an error with its stack trace. The channel is not
	 * closed here.
	 */
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception
	{
		if (cause instanceof IOException)
		{
			// Most likely the remote peer closed the connection.
			logger.debug("exceptionCaught--Remote Peer:{} {}", ctx.channel().remoteAddress(), cause.toString());
		}
		else
		{
			logger.error("exceptionCaught {} --Exeption={}", ctx.channel().remoteAddress(), cause.toString(), cause);
		}
	}

	/** Returns whether per-message debug tracing is wanted for protocolId (0 is excluded). */
	private boolean LogProtocolId(int protocolId)
	{
		return protocolId != 0;
	}

	/** Returns the simple class name of protocol's message, or "null" when it has none. */
	private static String messageName(InternalMessage protocol)
	{
		Object message = protocol.getMessage();
		return message == null ? "null" : message.getClass().getSimpleName();
	}

	/**
	 * Runs handler.handleMessage for protocol, wrapped in the TransactionMgr bookkeeping:
	 * <ul>
	 * <li>registers the current thread before the call;</li>
	 * <li>on failure, rolls back, logs and notifies MessageExceptionCatchHandler;</li>
	 * <li>always removes the thread's records afterwards.</li>
	 * </ul>
	 * Handler failures are not rethrown.
	 */
	private void invokeHandler(OrderedThreadPoolExecutor beforeExecutor, BaseMessageHandler<?> handler, InternalMessage protocol)
	{
		long threadId = Thread.currentThread().getId();
		try
		{
			// Register the current thread as the one processing this handler.
			TransactionMgr.getInstance().registerThreadId(threadId, protocol.getProtocolID(), protocol.getRemoteNode().getRoleId());
			handler.handleMessage(beforeExecutor, messageInitializer, protocol);
		}
		catch (Throwable e)
		{
			TransactionMgr.getInstance().rollback(threadId, protocol.getProtocolID());
			// Log before notifying, so the original failure is recorded even if the
			// notification throws. messageName also tolerates messages without a payload.
			logger.error("Failed to hander message:{},Exception err={}", messageName(protocol), e.toString(), e);
			MessageExceptionCatchHandler.getInstance().handleMessage(protocol.getRemoteNode(), protocol.getProtocolID(), protocol.getCallback(), protocol.getMessage());
		}
		finally
		{
			TransactionMgr.getInstance().removeRecords(threadId);
		}
	}

	/**
	 * Dispatches protocol to handler. When the initializer provides a "before" executor, the work
	 * runs there as a MessageTask; otherwise it runs synchronously on the calling thread.
	 * A null handler is logged and ignored.
	 */
	public void readMessage(final BaseMessageHandler<?> handler, final InternalMessage protocol)
	{
		if (handler == null)
		{
			logger.warn("readMessage: handler is null. can't handle this message. address is " + 
					(protocol.getRemoteNode() != null ? protocol.getRemoteNode().getAddress() : "null") + ", protocolId is " + protocol.getProtocolID());
			return;
		}

		final OrderedThreadPoolExecutor beforeExecutor = messageInitializer.getBeforeExecutor();
		if (beforeExecutor == null)
		{
			boolean trace = LogProtocolId(protocol.getProtocolID()) && logger.isDebugEnabled();
			if (trace)
				logger.debug("readMessage: run begin 2." + ", protocolId is " + protocol.getProtocolID());
			invokeHandler(beforeExecutor, handler, protocol);
			if (trace)
				logger.debug("readMessage: run end 2." + ", protocolId is " + protocol.getProtocolID());
			return;
		}

		beforeExecutor.execute(new MessageTask(messageInitializer, protocol, handler)
		{
			@Override
			public void run()
			{
				boolean trace = LogProtocolId(protocol.getProtocolID()) && logger.isDebugEnabled();
				if (trace)
				{
					RemoteNode node = protocol.getRemoteNode();
					logger.debug("readMessage: run begin. key is " + this.getKey() + ", address is " + 
							(node != null ? node.getAddress() : "null") + ", protocolId is " + protocol.getProtocolID() +
							", remote mode key is " + (node != null ? node.getInstanceId() : "null") + ", messageTask queue is " + this.getKeyName());
				}
				try
				{
					invokeHandler(beforeExecutor, handler, protocol);
				}
				finally
				{
					if (trace)
						logger.debug("readMessage: run end. key is " + this.getKey() + ", protocolId is " + protocol.getProtocolID());
				}
			}
		});
	}

	/**
	 * Registers a RemoteNode for a new UDP connection under its endpoint and dispatches it to the
	 * connection-active handler.
	 */
	@Override
	public void connectionActive(UdpConnection cn)
	{
		RemoteNode node = new RemoteNode(cn);
		node.setChannel(cn.getNettyChannel());
		this.remoteNodeMaps.put(cn.getEndPoint(), node);

		InternalMessage p = new InternalMessage();
		p.setRemoteNode(node);
		readMessage(messageInitializer.getConnectionActiveHandler(), p);
	}

	/**
	 * Removes the RemoteNode of a closed UDP connection and dispatches it to the
	 * connection-inactive handler. If no node is registered for the endpoint, this logs a warning
	 * and does nothing else.
	 */
	@Override
	public void connectionInActive(UdpConnection cn)
	{
		// A single remove is atomic on the concurrent map, unlike a get followed by a remove.
		RemoteNode node = this.remoteNodeMaps.remove(cn.getEndPoint());
		if (node == null)
		{
			logger.warn("connectionInActive cn {} channel {} not found remoteNode", cn.getEndPoint(), cn.getNettyChannel());
			return;
		}

		InternalMessage p = new InternalMessage();
		p.setRemoteNode(node);
		readMessage(messageInitializer.getConnectionInactiveHandler(), p);
	}
}
